{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "06f7f570",
   "metadata": {},
   "source": [
    "# Stix Dispersion Solver\n",
    "\n",
    "[stix]: ../../api_static/plasmapy.dispersion.analytical.stix_.rst\n",
    "[bellan2012]: https://doi.org/10.1029/2012ja017856\n",
    "[stix1992]: https://link.springer.com/book/9780883188590\n",
    "\n",
    "This notebook details the functionality of the [stix()][stix] function, which is an analytical solution of equation 8 in [Bellan 2012][bellan2012]. The dispersion relation is defined by [Stix 1992][stix1992] in §1.2 as:\n",
    "\n",
    "$$\n",
    "        (S \\sin^2(θ) + P \\cos^2(θ)) \\left ( \\frac{ck}{ω} \\right)^4\n",
    "            - [\n",
    "                RL \\sin^2(θ) + PS (1 + \\cos^2(θ))\n",
    "            ] \\left ( \\frac{ck}{ω} \\right)^2 + PRL = 0\n",
    "$$\n",
    "\n",
    "where,\n",
    "\n",
    "$$\n",
    "        \\mathbf{B}_0 = B_0 \\mathbf{\\hat{z}} \\\\\n",
    "        \\cos θ = \\frac{k_z}{k} \\\\\n",
    "        \\mathbf{k} = k_{\\rm x} \\hat{x} + k_{\\rm z} \\hat{z}\n",
    "$$\n",
    "\n",
    "$$\n",
    "        S = 1 - \\sum_s \\frac{ω^2_{p,s}}{ω^2 -\n",
    "            ω^2_{c,s}}\\hspace{2.5cm}\n",
    "        P = 1 - \\sum_s \\frac{ω^2_{p,s}}{ω^2}\\hspace{2.5cm}\n",
    "        D = \\sum_s\n",
    "            \\frac{ω_{c,s}}{ω}\n",
    "            \\frac{ω^2_{p,s}}{ω^2 - ω_{c,s}^2}\n",
    "$$\n",
    "\n",
    "$$\n",
    "        R = S + D \\hspace{1cm} L = S - D\n",
    "$$\n",
    "\n",
    "$ω$ is the wave frequency, $k$ is the wavenumber, $θ$ is the wave propagation angle with respect to the background magnetic field $\\mathbf{B}_0$, $s$ indexes the plasma species, $ω_{p,s}$ is the plasma frequency of species $s$, and $ω_{c,s}$ is the signed gyrofrequency of species $s$ (negative for electrons)."
   ]
  },
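  {
   "cell_type": "markdown",
   "id": "stix-params-sketch-md",
   "metadata": {},
   "source": [
    "As a cross-check of the definitions above, the next cell is a minimal NumPy-only sketch that evaluates $S$, $P$, $D$, $R$, and $L$ and solves the Stix relation for $n^2 = (ck/ω)^2$. It is our own illustration, not the implementation inside `stix()`: the helpers `stix_parameters` and `n_squared_roots` are hypothetical, the physical constants are hard-coded CODATA 2018 values, and the example plasma (electrons and protons at $10^{18}$ m$^{-3}$ in a 0.434534 T field, probed at $ω = 2 \\times 10^{10}$ rad/s) reuses the values of `inputs_3` defined later. The discriminant is written as $(RL - PS)^2 \\sin^4 θ + 4 P^2 D^2 \\cos^2 θ$, which follows from expanding the quadratic-formula discriminant with $RL = S^2 - D^2$. At $θ = 0$ the roots should reduce to $n^2 = R$ and $n^2 = L$, and at $θ = π/2$ to $n^2 = RL/S$ and $n^2 = P$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "stix-params-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# SI constants (CODATA 2018), hard-coded so this sketch only needs NumPy\n",
    "E_CHARGE = 1.602176634e-19  # elementary charge, C\n",
    "EPS0 = 8.8541878128e-12  # vacuum permittivity, F/m\n",
    "M_E = 9.1093837015e-31  # electron mass, kg\n",
    "M_P = 1.67262192369e-27  # proton mass, kg\n",
    "\n",
    "\n",
    "# hypothetical helper, not part of PlasmaPy\n",
    "def stix_parameters(omega, B0, species):\n",
    "    \"\"\"Return the cold-plasma Stix parameters (S, P, D, R, L).\n",
    "\n",
    "    `species` is a sequence of (charge number, mass in kg, density in m^-3)\n",
    "    tuples; `omega` is in rad/s and `B0` in T.\n",
    "    \"\"\"\n",
    "    S, P, D = 1.0, 1.0, 0.0\n",
    "    for Z, mass, density in species:\n",
    "        q = Z * E_CHARGE\n",
    "        wp2 = density * q**2 / (EPS0 * mass)  # plasma frequency squared\n",
    "        wc = q * B0 / mass  # signed gyrofrequency (negative for electrons)\n",
    "        S -= wp2 / (omega**2 - wc**2)\n",
    "        P -= wp2 / omega**2\n",
    "        D += (wc / omega) * wp2 / (omega**2 - wc**2)\n",
    "    return S, P, D, S + D, S - D\n",
    "\n",
    "\n",
    "# hypothetical helper, not part of PlasmaPy\n",
    "def n_squared_roots(theta, S, P, D, R, L):\n",
    "    \"\"\"Solve A n^4 - B n^2 + PRL = 0 for n^2 = (ck/ω)^2 at angle `theta` (rad).\"\"\"\n",
    "    sin2, cos2 = np.sin(theta) ** 2, np.cos(theta) ** 2\n",
    "    A = S * sin2 + P * cos2\n",
    "    B = R * L * sin2 + P * S * (1 + cos2)\n",
    "    # discriminant B^2 - 4 A PRL written as a sum of squares, so it is never negative\n",
    "    F = np.sqrt((R * L - P * S) ** 2 * sin2**2 + 4 * P**2 * D**2 * cos2)\n",
    "    return (B + F) / (2 * A), (B - F) / (2 * A)\n",
    "\n",
    "\n",
    "# electron-proton plasma with the density, field, and frequency of `inputs_3`\n",
    "species = [(-1, M_E, 1e18), (1, M_P, 1e18)]\n",
    "S, P, D, R, L = stix_parameters(2e10, 0.434534, species)\n",
    "print(\"θ = 0:   n^2 =\", n_squared_roots(0.0, S, P, D, R, L), \" R, L =\", (R, L))\n",
    "print(\"θ = π/2: n^2 =\", n_squared_roots(np.pi / 2, S, P, D, R, L), \" RL/S, P =\", (R * L / S, P))"
   ]
  },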
  {
   "cell_type": "markdown",
   "id": "013394df-d0eb-4be1-8c9f-b1d2f484ae7d",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "Note\n",
    "\n",
    "The derivation of this dispersion relation assumed:\n",
    "\n",
    " * zero temperature for all plasma species ($T_s=0$)\n",
    " * quasineutrality\n",
    " * a uniform background magnetic field $\\mathbf{B_0} = B_0 \\mathbf{\\hat{z}}$\n",
    " * no D.C. electric field $\\mathbf{E_0}=0$\n",
    " * zero-order quantities for all plasma parameters (densities, electric field, magnetic field, particle speeds, etc.) are constant in time and space\n",
    " * first-order perturbations in plasma parameters vary as $e^{i (\\textbf{k}\\cdot\\textbf{r} - \\omega t)}$\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bb0b482-b112-4652-a964-10952d2f0565",
   "metadata": {},
   "source": [
    "Because of the cold plasma assumption, this equation is valid for all $ω$ and $k$ provided that $\\frac{ω}{k_z} ≫ v_{Th}$ for the thermal speeds $v_{Th}$ of all plasma species and $k_x r_L ≪ 1$ for the gyroradii $r_L$ of all plasma species. The relation predicts $k → 0$ when any one of $P$, $R$, or $L$ vanishes (cutoffs) and $k → ∞$ for perpendicular propagation at a resonance, $S → 0$."
   ]
  },
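  {
   "cell_type": "markdown",
   "id": "cold-cutoffs-sketch-md",
   "metadata": {},
   "source": [
    "The cutoffs and the resonance mentioned above can be located in closed form for the simplest case of a single electron species, with the ions treated as an immobile neutralizing background. The next cell is a NumPy-only sketch under that assumption: the helper `stix_SPD` is hypothetical, the constants are hard-coded CODATA 2018 values, and the density and field reuse those of `inputs_3` defined later. It evaluates $S$, $P$, and $D$ from the definitions above and checks that $P$, $R = S + D$, $L = S - D$, and $S$ vanish at the electron plasma frequency $ω_{pe}$, at the cutoffs $ω_{R,L} = \\frac{1}{2}\\left(\\pm ω_{ce} + \\sqrt{ω_{ce}^2 + 4ω_{pe}^2}\\right)$, and at the upper hybrid frequency $ω_{UH} = \\sqrt{ω_{pe}^2 + ω_{ce}^2}$, respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cold-cutoffs-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# SI constants (CODATA 2018), hard-coded so this sketch only needs NumPy\n",
    "E_CHARGE = 1.602176634e-19  # elementary charge, C\n",
    "EPS0 = 8.8541878128e-12  # vacuum permittivity, F/m\n",
    "M_E = 9.1093837015e-31  # electron mass, kg\n",
    "\n",
    "# assumption: electrons only, ions form an immobile neutralizing background;\n",
    "# the density and field are those of `inputs_3`\n",
    "n_e = 1e18  # m^-3\n",
    "B0 = 0.434534  # T\n",
    "wpe = np.sqrt(n_e * E_CHARGE**2 / (EPS0 * M_E))  # electron plasma frequency, rad/s\n",
    "wce = E_CHARGE * B0 / M_E  # magnitude of the electron gyrofrequency, rad/s\n",
    "omega_c = -wce  # signed electron gyrofrequency, as used in S and D\n",
    "\n",
    "\n",
    "# hypothetical helper, not part of PlasmaPy\n",
    "def stix_SPD(omega):\n",
    "    \"\"\"Return S, P, D for the single electron species, following the definitions above.\"\"\"\n",
    "    S = 1 - wpe**2 / (omega**2 - omega_c**2)\n",
    "    P = 1 - wpe**2 / omega**2\n",
    "    D = (omega_c / omega) * wpe**2 / (omega**2 - omega_c**2)\n",
    "    return S, P, D\n",
    "\n",
    "\n",
    "# closed-form cutoff and resonance frequencies for this electron-only model\n",
    "w_P = wpe  # P = 0\n",
    "w_R = 0.5 * (wce + np.sqrt(wce**2 + 4 * wpe**2))  # R = S + D = 0\n",
    "w_L = 0.5 * (-wce + np.sqrt(wce**2 + 4 * wpe**2))  # L = S - D = 0\n",
    "w_UH = np.sqrt(wpe**2 + wce**2)  # S = 0 (upper hybrid resonance)\n",
    "\n",
    "for label, w in [(\"P cutoff\", w_P), (\"R cutoff\", w_R), (\"L cutoff\", w_L), (\"UH\", w_UH)]:\n",
    "    S, P, D = stix_SPD(w)\n",
    "    print(\n",
    "        f\"{label}: ω = {w:.4e} rad/s, \"\n",
    "        f\"S = {S:+.1e}, P = {P:+.1e}, R = {S + D:+.1e}, L = {S - D:+.1e}\"\n",
    "    )"
   ]
  },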
  {
   "cell_type": "markdown",
   "id": "f4196cd1-3f7c-4e4d-b827-7836974c002c",
   "metadata": {},
   "source": [
    "## Contents\n",
    "\n",
    "1. [Wave Normal to the Surface](#Wave-Normal-to-the-Surface)\n",
    "2. [Comparison with Bellan](#Comparison-with-Bellan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c39b442",
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "\n",
    "import astropy.units as u\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scipy.optimize\n",
    "from astropy.constants.si import c\n",
    "\n",
    "from plasmapy.dispersion.analytical.stix_ import stix\n",
    "from plasmapy.dispersion.analytical.two_fluid_ import two_fluid\n",
    "from plasmapy.formulary import speeds\n",
    "from plasmapy.particles import Particle\n",
    "\n",
    "plt.rcParams[\"figure.figsize\"] = [10.5, 0.56 * 10.5]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5d0b2fc",
   "metadata": {},
   "source": [
    "## Wave Normal to the Surface\n",
    "\n",
    "Here we calculate the wave normal surfaces of waves propagating through a uniform, magnetized, cold plasma. The wave normal surface is the locus traced out by the tip of the phase velocity vector $\\textbf{v}_{phase} = \\frac{ω}{k} \\, \\hat{k}$, where $\\hat{k} = \\frac{\\textbf{k}}{k}$, as the propagation angle $θ$ varies. Its equation follows from the dispersion relation above and takes the form\n",
    "\n",
    "$$\n",
    "    A u^4 + B u^2 + C = 0\n",
    "$$\n",
    "\n",
    "where $u = \\frac{ω}{ck}$. To begin, we define the parameters required to compute the wavenumbers."
   ]
  },
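  {
   "cell_type": "markdown",
   "id": "wave-normal-coefficients",
   "metadata": {},
   "source": [
    "For reference, substituting $ck/ω = 1/u$ into the Stix relation at the top of this notebook and multiplying through by $u^4$ gives the coefficients of $A u^4 + B u^2 + C = 0$ (this is our own rearrangement, not a formula quoted from Stix):\n",
    "\n",
    "$$\n",
    "    A = PRL, \\qquad\n",
    "    B = -\\left[ RL \\sin^2(θ) + PS (1 + \\cos^2(θ)) \\right], \\qquad\n",
    "    C = S \\sin^2(θ) + P \\cos^2(θ)\n",
    "$$\n",
    "\n",
    "The quadratic formula then gives\n",
    "\n",
    "$$\n",
    "    u^2 = \\frac{-B \\pm \\sqrt{B^2 - 4AC}}{2A}, \\qquad\n",
    "    B^2 - 4AC = (RL - PS)^2 \\sin^4(θ) + 4 P^2 D^2 \\cos^2(θ),\n",
    "$$\n",
    "\n",
    "where the second expression uses $RL = S^2 - D^2$. Since the discriminant is a sum of squares, $u^2$ is always real, and only roots with $u^2 > 0$ give propagating waves; these trace out the surface through $u_x = u \\sin θ$ and $u_z = u \\cos θ$."
   ]
  },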
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab1a4b91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define input parameters\n",
    "inputs_1 = {\n",
    "    \"theta\": np.linspace(0, np.pi, 50) * u.rad,\n",
    "    \"ions\": Particle(\"p\"),\n",
    "    \"n_i\": 1e12 * u.cm**-3,\n",
    "    \"B\": 0.43463483142776164 * u.T,\n",
    "    \"w\": 41632.94534008216 * u.rad / u.s,\n",
    "}\n",
    "\n",
    "# define a meshgrid based on the number of theta values\n",
    "omegas, thetas = np.meshgrid(\n",
    "    inputs_1[\"w\"].value, inputs_1[\"theta\"].value, indexing=\"ij\"\n",
    ")\n",
    "omegas = np.dstack((omegas,) * 4).squeeze()\n",
    "thetas = np.dstack((thetas,) * 4).squeeze()\n",
    "\n",
    "# compute k values\n",
    "k = stix(**inputs_1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57d74632",
   "metadata": {},
   "source": [
    "[Quantity]: https://docs.astropy.org/en/stable/api/astropy.units.Quantity.html#astropy.units.Quantity\n",
    "\n",
    "The computed wavenumbers, in units of rad/m, are returned as an Astropy [Quantity] array of shape $N × M × 4$. The first dimension maps to the $ω$ (`w`) array, the second dimension maps to the $θ$ array, and the third dimension maps to the four roots of the Stix polynomial:\n",
    "\n",
    "* `k[..., 0]` is the square root of the positive quadratic solution\n",
    "* `k[..., 1] = -k[..., 0]`\n",
    "* `k[..., 2]` is the square root of the negative quadratic solution\n",
    "* `k[..., 3] = -k[..., 2]`\n",
    "\n",
    "Below, the components $u_x$ and $u_z$ of the wave normal surface are calculated, keeping only the purely real roots."
   ]
  },
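  {
   "cell_type": "markdown",
   "id": "wave-normal-uxuz-sketch-md",
   "metadata": {},
   "source": [
    "The next cell is a NumPy-only sketch, under our own assumptions, of how $u_x = u \\sin θ$ and $u_z = u \\cos θ$ with $u = ω / (ck)$ can be obtained from a wavenumber array of shape $N × M × 4$ by broadcasting $ω$ along the first axis and $θ$ along the second. The helper `wave_normal_components` is hypothetical: it expects plain arrays (for example `k.value` in rad/m and $ω$ in rad/s) with all three dimensions present, and it replaces complex or non-positive roots with NaN instead of dropping them, so the output keeps the input shape. It omits the extra `norm` factor applied in the following cell, and the small input array is made up purely to exercise the function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "wave-normal-uxuz-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "C_LIGHT = 299_792_458.0  # speed of light, m/s\n",
    "\n",
    "\n",
    "# hypothetical helper, not part of PlasmaPy\n",
    "def wave_normal_components(k, omega, theta):\n",
    "    \"\"\"Return (u_x, u_z) with u = ω / (c k) for a wavenumber array of shape (N, M, 4).\n",
    "\n",
    "    `k` is in rad/m, `omega` (length N) in rad/s and `theta` (length M) in rad.\n",
    "    Complex or non-positive roots are replaced by NaN, so they drop out of plots.\n",
    "    \"\"\"\n",
    "    k = np.asarray(k)\n",
    "    omega = np.asarray(omega, dtype=float)[:, np.newaxis, np.newaxis]  # (N, 1, 1)\n",
    "    theta = np.asarray(theta, dtype=float)[np.newaxis, :, np.newaxis]  # (1, M, 1)\n",
    "\n",
    "    # keep only purely real, positive wavenumbers\n",
    "    valid = (np.imag(k) == 0) & (np.real(k) > 0)\n",
    "    k_real = np.where(valid, np.real(k), np.nan)\n",
    "\n",
    "    u = omega / (C_LIGHT * k_real)  # broadcasts to (N, M, 4)\n",
    "    return u * np.sin(theta), u * np.cos(theta)\n",
    "\n",
    "\n",
    "# made-up example: one frequency (ω = c, so u = 1/k) and two angles\n",
    "k_demo = np.array([[[2.0, -2.0, 4.0, -4.0], [1 + 1j, -1 - 1j, 0.5, -0.5]]])\n",
    "u_x, u_z = wave_normal_components(k_demo, [C_LIGHT], [0.0, np.pi / 2])\n",
    "print(u_x)\n",
    "print(u_z)"
   ]
  },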
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "761f2782",
   "metadata": {},
   "outputs": [],
   "source": [
    "# calculate ux and uz\n",
    "\n",
    "u_v = {}\n",
    "\n",
    "mask = np.imag(k) == 0\n",
    "\n",
    "va_1 = speeds.va_(inputs_1[\"B\"], inputs_1[\"n_i\"], ion=inputs_1[\"ions\"])\n",
    "for arr in k:\n",
    "    val = 0\n",
    "    for item in arr:\n",
    "        val = val + item**2\n",
    "    norm = (np.sqrt(val) * va_1 / inputs_1[\"w\"]).value ** 2\n",
    "    u_v = {\n",
    "        \"ux\": norm * omegas[mask] * np.sin(thetas[mask]) / (k.value[mask] * c.value),\n",
    "        \"uz\": norm * omegas[mask] * np.cos(thetas[mask]) / (k.value[mask] * c.value),\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4084ffc7",
   "metadata": {},
   "source": [
    "Let's plot the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcb06d10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot the results\n",
    "\n",
    "fs = 14  # default font size\n",
    "figwidth, figheight = plt.rcParams[\"figure.figsize\"]\n",
    "figheight = 1.6 * figheight\n",
    "fig = plt.figure(figsize=[figwidth, figheight])\n",
    "\n",
    "plt.scatter(\n",
    "    u_v[\"ux\"],\n",
    "    u_v[\"uz\"],\n",
    "    label=\"Stix: Fig. 1-1\",\n",
    ")\n",
    "\n",
    "# adjust axes\n",
    "plt.xlabel(r\"$u_x$\", fontsize=fs)\n",
    "plt.ylabel(r\"$u_z$\", fontsize=fs)\n",
    "\n",
    "pad = 1.25\n",
    "plt.ylim(min(u_v[\"uz\"]) * pad, max(u_v[\"uz\"]) * pad)\n",
    "plt.xlim(min(u_v[\"ux\"]) * pad, max(u_v[\"ux\"]) * pad)\n",
    "\n",
    "plt.tick_params(\n",
    "    which=\"both\",\n",
    "    direction=\"in\",\n",
    "    labelsize=fs,\n",
    "    right=True,\n",
    "    length=5,\n",
    ")\n",
    "\n",
    "# plot caption\n",
    "txt = (\n",
    "    \"Fig. 1-1: Wave normal surfaces for parameters representing\\nthe shear Alfvén wave \"\n",
    "    \"and the compressional Alfvén wave.\\nThe zero-order magnetic field is directed along \"\n",
    "    \"the z-axis.\"\n",
    ")\n",
    "\n",
    "plt.figtext(0.25, -0.04, txt, ha=\"left\", fontsize=fs)\n",
    "plt.legend(loc=\"upper left\", markerscale=1, fontsize=fs)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68791198",
   "metadata": {},
   "source": [
    "[stix1992]: https://link.springer.com/book/9780883188590\n",
    "\n",
    "Here we define the parameters for further wave normal surface plots in [Stix 1992][stix1992] (Figs. 1-2 to 1-5) and reproduce them in the same way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "198fe5d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define inputs\n",
    "inputs_2 = {\n",
    "    \"theta\": np.linspace(0, np.pi, 100) * u.rad,\n",
    "    \"ions\": Particle(\"p\"),\n",
    "    \"n_i\": 1e12 * u.cm**-3,\n",
    "    \"B\": 0.434 * u.T,\n",
    "    \"w\": (37125810) * u.rad / u.s,\n",
    "}\n",
    "\n",
    "inputs_3 = {\n",
    "    \"theta\": np.linspace(0, np.pi, 100) * u.rad,\n",
    "    \"ions\": Particle(\"p\"),\n",
    "    \"n_i\": 1e12 * u.cm**-3,\n",
    "    \"B\": 0.434534 * u.T,\n",
    "    \"w\": (2 * 10**10) * u.rad / u.s,\n",
    "}\n",
    "\n",
    "inputs_4 = {\n",
    "    \"theta\": np.linspace(0, np.pi, 100) * u.rad,\n",
    "    \"ions\": Particle(\"p\"),\n",
    "    \"n_i\": 1e12 * u.cm**-3,\n",
    "    \"B\": 0.434600 * u.T,\n",
    "    \"w\": (54 * 10**9) * u.rad / u.s,\n",
    "}\n",
    "\n",
    "inputs_5 = {\n",
    "    \"theta\": np.linspace(0, np.pi, 100) * u.rad,\n",
    "    \"ions\": Particle(\"p\"),\n",
    "    \"n_i\": 1e12 * u.cm**-3,\n",
    "    \"B\": 0.434634 * u.T,\n",
    "    \"w\": (58 * 10**9) * u.rad / u.s,\n",
    "}\n",
    "\n",
    "# define a list of all inputs\n",
    "stix_inputs = [inputs_1, inputs_2, inputs_3, inputs_4, inputs_5]"
   ]
  },
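  {
   "cell_type": "markdown",
   "id": "freq-regimes-sketch-md",
   "metadata": {},
   "source": [
    "Before computing these surfaces, it can help to see which frequency range each parameter set probes. The next cell is a NumPy-only sketch (hard-coded CODATA 2018 constants; the helper `characteristic_frequencies` is hypothetical) that compares each $ω$ with the electron plasma frequency $ω_{pe}$ and the electron and proton gyrofrequencies $ω_{ce}$ and $ω_{cp}$, assuming $n_e = n_i = 10^{18}$ m$^{-3}$ as in the inputs above. With these constants, `inputs_2` should fall below $ω_{cp}$, while `inputs_4` and `inputs_5` should sit just below and just above $ω_{pe}$, respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "freq-regimes-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# SI constants (CODATA 2018), hard-coded so this sketch only needs NumPy\n",
    "E_CHARGE = 1.602176634e-19  # elementary charge, C\n",
    "EPS0 = 8.8541878128e-12  # vacuum permittivity, F/m\n",
    "M_E = 9.1093837015e-31  # electron mass, kg\n",
    "M_P = 1.67262192369e-27  # proton mass, kg\n",
    "\n",
    "\n",
    "# hypothetical helper, not part of PlasmaPy\n",
    "def characteristic_frequencies(n, B0):\n",
    "    \"\"\"Return (ω_pe, ω_ce, ω_cp) in rad/s for density n (m^-3) and field B0 (T).\"\"\"\n",
    "    wpe = np.sqrt(n * E_CHARGE**2 / (EPS0 * M_E))\n",
    "    wce = E_CHARGE * B0 / M_E\n",
    "    wcp = E_CHARGE * B0 / M_P\n",
    "    return wpe, wce, wcp\n",
    "\n",
    "\n",
    "# (B in T, ω in rad/s) copied from inputs_2 to inputs_5; n_i = 1e12 cm^-3 = 1e18 m^-3\n",
    "parameter_sets = {\n",
    "    \"inputs_2\": (0.434, 37125810.0),\n",
    "    \"inputs_3\": (0.434534, 2e10),\n",
    "    \"inputs_4\": (0.434600, 54e9),\n",
    "    \"inputs_5\": (0.434634, 58e9),\n",
    "}\n",
    "\n",
    "ratios = {}\n",
    "for name, (B0, w) in parameter_sets.items():\n",
    "    wpe, wce, wcp = characteristic_frequencies(1e18, B0)\n",
    "    ratios[name] = {\"w/wpe\": w / wpe, \"w/wce\": w / wce, \"w/wcp\": w / wcp}\n",
    "    print(\n",
    "        f\"{name}: ω/ω_pe = {w / wpe:.3g}, ω/ω_ce = {w / wce:.3g}, ω/ω_cp = {w / wcp:.3g}\"\n",
    "    )"
   ]
  },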
  {
   "cell_type": "markdown",
   "id": "a1854c09",
   "metadata": {},
   "source": [
    "The same procedure used for the first set of input parameters is now applied to the remaining sets, after which the results for all inputs can be plotted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a934bae",
   "metadata": {},
   "outputs": [],
   "source": [
    "stix_plt = {}\n",
    "\n",
    "for i, inpt in enumerate(stix_inputs):\n",
    "    omegas, thetas = np.meshgrid(inpt[\"w\"].value, inpt[\"theta\"].value, indexing=\"ij\")\n",
    "    omegas = np.dstack((omegas,) * 4).squeeze()\n",
    "    thetas = np.dstack((thetas,) * 4).squeeze()\n",
    "\n",
    "    k = stix(**inpt)\n",
    "\n",
    "    mask = np.imag(k) == 0\n",
    "\n",
    "    va = speeds.va_(inpt[\"B\"], inpt[\"n_i\"], ion=inpt[\"ions\"])\n",
    "\n",
    "    for arr in k:\n",
    "        val = 0\n",
    "        for item in arr:\n",
    "            val = val + item**2\n",
    "        norm = (np.sqrt(val) * va / inpt[\"w\"]).value ** 2\n",
    "        stix_plt[i] = {\n",
    "            \"ux\": norm\n",
    "            * omegas[mask]\n",
    "            * np.sin(thetas[mask])\n",
    "            / (k.value[mask] * c.value),\n",
    "            \"uz\": norm\n",
    "            * omegas[mask]\n",
    "            * np.cos(thetas[mask])\n",
    "            / (k.value[mask] * c.value),\n",
    "        }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1df145c",
   "metadata": {},
   "source": [
    "Plot the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdb0343d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create figure\n",
    "fig, axs = plt.subplots(2, 2, figsize=[figwidth, figheight])\n",
    "\n",
    "\n",
    "for i in range(2):\n",
    "    for j in range(2):\n",
    "        axs[i, j].scatter(\n",
    "            stix_plt[i + 2 * j + 1][\"ux\"], stix_plt[i + 2 * j + 1][\"uz\"]\n",
    "        )\n",
    "        axs[i, j].set_title(\"Stix: Fig. 1-\" + str(i + 2 * j + 2), fontsize=fs)\n",
    "\n",
    "        # adjust axes\n",
    "        axs[i, j].set(\n",
    "            ylabel=r\"$u_z$\",\n",
    "            xlabel=r\"$u_x$\",\n",
    "        )\n",
    "\n",
    "        pad = 1.25\n",
    "        axs[i, j].set_ylim(\n",
    "            min(stix_plt[i + 2 * j + 1][\"uz\"]) * pad,\n",
    "            max(stix_plt[i + 2 * j + 1][\"uz\"]) * pad,\n",
    "        )\n",
    "        axs[i, j].set_xlim(\n",
    "            min(stix_plt[i + 2 * j + 1][\"ux\"]) * pad,\n",
    "            max(stix_plt[i + 2 * j + 1][\"ux\"]) * pad,\n",
    "        )\n",
    "\n",
    "        axs[i, j].tick_params(\n",
    "            which=\"both\",\n",
    "            direction=\"in\",\n",
    "            labelsize=fs,\n",
    "            right=True,\n",
    "            length=5,\n",
    "        )\n",
    "\n",
    "\n",
    "# plot caption\n",
    "txt = \"Wave normal surface reproduced from Stix.\"\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.figtext(0.35, -0.02, txt, ha=\"left\", fontsize=fs)\n",
    "\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2e02554",
   "metadata": {},
   "source": [
    "[stix]: ../../api_static/plasmapy.dispersion.analytical.stix_.rst\n",
    "[bellan2012]: https://doi.org/10.1029/2012ja017856\n",
    "\n",
    "## Comparison with Bellan\n",
    "\n",
    "Below we compare the solution given in [Bellan 2012][bellan2012], as computed by `two_fluid()`, with solutions computed from [stix()][stix]. To begin, we create a function that converts each solution into the normalized coordinates of the Bellan plot, $(ω / k v_A)^2 \\sin θ$ and $(ω / k v_A)^2 \\cos θ$.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cb63718",
   "metadata": {},
   "outputs": [],
   "source": [
    "def norm_bellan_plot(**inputs):\n",
    "    \"\"\"Return the normalized coordinates used to plot the Bellan dispersion relation.\"\"\"\n",
    "    w = inputs[\"w\"]\n",
    "    k = inputs[\"k\"]\n",
    "    theta = inputs[\"theta\"]\n",
    "\n",
    "    if w.shape == k.shape or w.size == 1 or k.size == 1:\n",
    "        pass\n",
    "    elif w.ndim > 2 or k.ndim > 2 or k.shape[0] != w.shape[0]:\n",
    "        raise ValueError(\"w and k have incompatible shapes\")\n",
    "    elif k.ndim > w.ndim:\n",
    "        w = np.repeat(w[..., np.newaxis], k.shape[1], axis=1)\n",
    "    elif k.ndim < w.ndim:\n",
    "        k = np.repeat(k[..., np.newaxis], w.shape[1], axis=1)\n",
    "\n",
    "    if theta.ndim != 1 or theta.size != w.shape[-1]:\n",
    "        raise ValueError(\"theta must be 1D with size equal to the last dimension of w\")\n",
    "\n",
    "    try:\n",
    "        ion = inputs[\"ion\"]\n",
    "    except KeyError:\n",
    "        ion = inputs[\"ions\"][0]\n",
    "    va = speeds.va_(inputs[\"B\"], inputs[\"n_i\"], ion=ion)\n",
    "\n",
    "    mag = ((w / (k * va)).to(u.dimensionless_unscaled).value) ** 2\n",
    "    theta = theta.to(u.radian).value\n",
    "\n",
    "    xnorm = mag * np.sin(theta)\n",
    "    ynorm = mag * np.cos(theta)\n",
    "\n",
    "    return np.array([xnorm, ynorm])"
   ]
  },
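  {
   "cell_type": "markdown",
   "id": "alfven-norm-sketch-md",
   "metadata": {},
   "source": [
    "The normalization in `norm_bellan_plot()` divides by the Alfvén speed. As a rough, self-contained check, the next cell computes $v_A = B / \\sqrt{μ_0 n_i m_i}$ for the Bellan parameters defined below ($B = 400$ G, $n_i = 6.358 \\times 10^{19}$ m$^{-3}$, He-4 1+ ions) with NumPy only. This is a sketch under assumptions: only the ion mass density is included (our understanding of what `speeds.va_` uses when given an ion and a number density), the He-4 1+ mass is approximated as the He-4 atomic mass minus one electron mass, and `bellan_coordinates` is a hypothetical unit-free version of the normalization, not the function above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "alfven-norm-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# SI constants (CODATA 2018), hard-coded so this sketch only needs NumPy\n",
    "MU0 = 1.25663706212e-6  # vacuum permeability, H/m\n",
    "AMU = 1.66053906660e-27  # atomic mass unit, kg\n",
    "M_E = 9.1093837015e-31  # electron mass, kg\n",
    "\n",
    "B_bellan = 400e-4  # T\n",
    "n_bellan = 6.358e19  # ion number density, m^-3\n",
    "m_he_ion = 4.00260325 * AMU - M_E  # approximate He-4 1+ mass, kg\n",
    "\n",
    "# assumption: only the ion mass density enters the Alfvén speed\n",
    "v_alfven = B_bellan / np.sqrt(MU0 * n_bellan * m_he_ion)  # m/s\n",
    "print(f\"v_A ≈ {v_alfven:.4g} m/s\")\n",
    "\n",
    "\n",
    "# hypothetical helper, not part of PlasmaPy\n",
    "def bellan_coordinates(w, k, theta, v_A):\n",
    "    \"\"\"Return ((ω / k v_A)^2 sin θ, (ω / k v_A)^2 cos θ) for unit-free arrays.\"\"\"\n",
    "    mag = (np.asarray(w) / (np.asarray(k) * v_A)) ** 2\n",
    "    return mag * np.sin(theta), mag * np.cos(theta)"
   ]
  },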
  {
   "cell_type": "markdown",
   "id": "58f057ba",
   "metadata": {},
   "source": [
    "Now we can compute the Bellan solution for the same plasma parameters. In the first instance, a cold plasma limit of $k_B T_e = 0.1$ eV and $k_B T_i = 0.1$ eV is assumed. In the second instance, a warm plasma limit of $k_B T_e = 20$ eV and $k_B T_i = 10$ eV is assumed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c76faff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# defining all inputs\n",
    "base_inputs = {\n",
    "    \"k\": (2 * np.pi * u.rad) / (0.56547 * u.m),\n",
    "    \"theta\": np.linspace(0, 0.49 * np.pi, 50) * u.rad,\n",
    "    \"ion\": Particle(\"He-4 1+\"),\n",
    "    \"n_i\": 6.358e19 * u.m**-3,\n",
    "    \"B\": 400e-4 * u.T,\n",
    "}\n",
    "\n",
    "hot_inputs = {\n",
    "    **base_inputs,\n",
    "    \"T_e\": 20 * u.eV,\n",
    "    \"T_i\": 10 * u.eV,\n",
    "}\n",
    "\n",
    "cold_inputs = {\n",
    "    **base_inputs,\n",
    "    \"T_e\": 0.1 * u.eV,\n",
    "    \"T_i\": 0.1 * u.eV,\n",
    "}\n",
    "\n",
    "# calculating the solution from two fluid\n",
    "w_tf_hot = two_fluid(**hot_inputs)\n",
    "w_tf_cold = two_fluid(**cold_inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b8f3534",
   "metadata": {},
   "source": [
    "Convert the two-fluid solutions into the normalized Bellan coordinates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7902e8ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt_tf_hot = {}\n",
    "for key, val in w_tf_hot.items():\n",
    "    plt_tf_hot[key] = norm_bellan_plot(**{**hot_inputs, \"w\": val})\n",
    "\n",
    "plt_tf_cold = {}\n",
    "for key, val in w_tf_cold.items():\n",
    "    plt_tf_cold[key] = norm_bellan_plot(**{**cold_inputs, \"w\": val})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93bab73c",
   "metadata": {},
   "source": [
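    "The Stix inputs need to be recalculated using the Bellan inputs as a base: the ion is passed through the `ions` keyword, and `k` is removed because it is an output of `stix()` rather than an input."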
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fca167a",
   "metadata": {},
   "outputs": [],
   "source": [
    "stix_inputs = {**base_inputs, \"ions\": [base_inputs[\"ion\"]]}\n",
    "del stix_inputs[\"k\"]\n",
    "del stix_inputs[\"ion\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57469172",
   "metadata": {},
   "source": [
    "Bellan fixes $k$ and then calculates $ω$ for each mode and propagation angle $θ$, whereas `stix()` takes $ω$ and $θ$ as inputs and returns $k$. This means we cannot simply take the Bellan solutions and read off the corresponding $k$ values from Stix. Instead, we create a version of `stix()` whose root, found with `scipy.optimize.root_scalar()`, is the frequency $ω$ at which `stix()` returns the Bellan value of $k$.\n"
   ]
  },
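  {
   "cell_type": "markdown",
   "id": "root-finding-toy-md",
   "metadata": {},
   "source": [
    "The same inversion idea can be illustrated on a toy problem with a known answer. The next cell is a sketch, not part of the Stix comparison: it uses the unmagnetized cold-plasma light-wave relation $ω^2 = ω_p^2 + c^2 k^2$ in normalized units ($ω$ in units of $ω_p$, $k$ in units of $ω_p / c$), so that $k(ω) = \\sqrt{ω^2 - 1}$, and asks `scipy.optimize.root_scalar()` with the secant method (two starting points `x0` and `x1`, as the notebook does further down) for the $ω$ at which $k(ω) = 3$. The result can be compared with the exact answer $ω = \\sqrt{10}$. The functions `k_of_w` and `residual` are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "root-finding-toy",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy.optimize\n",
    "\n",
    "\n",
    "# hypothetical toy dispersion relation, not part of PlasmaPy\n",
    "def k_of_w(w):\n",
    "    \"\"\"Return k(ω) = sqrt(ω^2 - 1) in normalized units (ω / ω_p, k c / ω_p).\"\"\"\n",
    "    return np.sqrt(w**2 - 1.0)\n",
    "\n",
    "\n",
    "def residual(w, k_target):\n",
    "    \"\"\"Return zero when k_of_w gives k_target; abs() mirrors stix_optimize().\"\"\"\n",
    "    return k_of_w(np.abs(w)) - k_target\n",
    "\n",
    "\n",
    "k_target = 3.0\n",
    "result = scipy.optimize.root_scalar(\n",
    "    residual, args=(k_target,), method=\"secant\", x0=2.0, x1=2.1\n",
    ")\n",
    "print(result.root, np.sqrt(1.0 + k_target**2))"
   ]
  },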
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "724eea0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# partially bind plasma parameter keywords to stix()\n",
    "_opt = stix_inputs.copy()\n",
    "del _opt[\"theta\"]\n",
    "stix_partial = functools.partial(stix, **_opt)\n",
    "\n",
    "\n",
    "def stix_optimize(w, theta, mode, k_expected):\n",
    "    \"\"\"Version of `stix` that can be optimized by `scipy.optimize.root_scalar`.\"\"\"\n",
    "    w = np.abs(w)\n",
    "    results = stix_partial(w=w * u.rad / u.s, theta=theta * u.rad).value\n",
    "\n",
    "    # only consider real and positive solutions\n",
    "    real_mask = np.where(np.imag(results) == 0, True, False)\n",
    "    pos_mask = np.where(np.real(results) > 0, True, False)\n",
    "    mask = np.logical_and(real_mask, pos_mask)\n",
    "\n",
    "    # get the correct k to compare\n",
    "    if np.count_nonzero(mask) == 1:\n",
    "        results = np.real(results[mask][0])\n",
    "    elif mode == \"fast_mode\":\n",
    "        # fast_mode has a larger phase velocity than\n",
    "        # the alfven_mode, thus take the smaller k-value\n",
    "        results = np.min(np.real(results[mask]))\n",
    "    else:  # alfven_mode\n",
    "        results = np.max(np.real(results[mask]))\n",
    "\n",
    "    return results - k_expected"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb4e6ace",
   "metadata": {},
   "source": [
    "Let's use the cold plasma Bellan solution to seed the Stix solution. Note that only the `fast_mode` and `alfven_mode` solutions are used as seeds, because the `acoustic_mode` disappears in the cold plasma limit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50439257",
   "metadata": {},
   "outputs": [],
   "source": [
    "theta_arr = cold_inputs[\"theta\"].value\n",
    "k_expected = base_inputs[\"k\"].value\n",
    "k_stix = {}\n",
    "w_stix = {}\n",
    "for mode in (\"fast_mode\", \"alfven_mode\"):\n",
    "    w_arr = w_tf_cold[mode].value\n",
    "    k_stix[mode] = []\n",
    "    w_stix[mode] = []\n",
    "    for ii in range(w_arr.size):\n",
    "        w_guess = w_arr[ii]\n",
    "        _theta = theta_arr[ii]\n",
    "        result = scipy.optimize.root_scalar(\n",
    "            stix_optimize,\n",
    "            args=(_theta, mode, k_expected),\n",
    "            x0=w_guess,\n",
    "            x1=w_guess + 1e2,\n",
    "        )\n",
    "\n",
    "        # append the wavefrequency (result.root) that\n",
    "        # corresponded to stix() returning k_expected\n",
    "        w_stix[mode].append(np.real(result.root))\n",
    "\n",
    "        # double check and store the k-value\n",
    "        _k = stix(\n",
    "            **{\n",
    "                **stix_inputs,\n",
    "                \"w\": np.real(result.root) * u.rad / u.s,\n",
    "                \"theta\": theta_arr[ii] * u.rad,\n",
    "            }\n",
    "        ).value\n",
    "        real_mask = np.where(np.imag(_k) == 0, True, False)\n",
    "        pos_mask = np.where(np.real(_k) > 0, True, False)\n",
    "        mask = np.logical_and(real_mask, pos_mask)\n",
    "\n",
    "        _k = np.real(_k[mask])\n",
    "        mask = np.isclose(_k, base_inputs[\"k\"].value)\n",
    "        k_stix[mode].append(_k[mask][0])\n",
    "\n",
    "    k_stix[mode] = np.array(k_stix[mode])\n",
    "    w_stix[mode] = np.array(w_stix[mode])\n",
    "\n",
    "(\n",
    "    k_expected,\n",
    "    k_stix,\n",
    "    w_stix,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f213f0c",
   "metadata": {},
   "source": [
    "Create normalized arrays for plotting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a62726cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt_stix = {}\n",
    "for key, val in k_stix.items():\n",
    "    plt_stix[key] = norm_bellan_plot(\n",
    "        **{**stix_inputs, \"k\": val * u.rad / u.m, \"w\": w_stix[key] * u.rad / u.s}\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "221f86a8",
   "metadata": {},
   "source": [
    "Plot the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ad3ab00",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=[figwidth, figheight])\n",
    "\n",
    "mode = \"fast_mode\"\n",
    "plt.plot(\n",
    "    plt_stix[mode][0, ...],\n",
    "    plt_stix[mode][1, ...],\n",
    "    \"--\",\n",
    "    linewidth=3,\n",
    "    label=\"Fast Mode - Stix\",\n",
    ")\n",
    "plt.plot(\n",
    "    plt_tf_cold[mode][0, ...],\n",
    "    plt_tf_cold[mode][1, ...],\n",
    "    label=\"Fast Mode - Bellan Cold Plasma\",\n",
    ")\n",
    "plt.plot(\n",
    "    plt_tf_hot[mode][0, ...],\n",
    "    plt_tf_hot[mode][1, ...],\n",
    "    label=\"Fast Mode - Bellan Hot Plasma\",\n",
    ")\n",
    "\n",
    "mode = \"alfven_mode\"\n",
    "plt.plot(\n",
    "    plt_stix[mode][0, ...],\n",
    "    plt_stix[mode][1, ...],\n",
    "    \"--\",\n",
    "    linewidth=3,\n",
    "    label=\"Alfvén Mode - Stix\",\n",
    ")\n",
    "plt.plot(\n",
    "    plt_tf_cold[mode][0, ...],\n",
    "    plt_tf_cold[mode][1, ...],\n",
    "    label=\"Alfvén Mode - Bellan Cold Plasma\",\n",
    ")\n",
    "plt.plot(\n",
    "    plt_tf_hot[mode][0, ...],\n",
    "    plt_tf_hot[mode][1, ...],\n",
    "    label=\"Alfvén Mode - Bellan Hot Plasma\",\n",
    ")\n",
    "\n",
    "plt.legend(fontsize=fs)\n",
    "\n",
    "plt.xlabel(r\"$(ω / k v_A)^2 \\, \\sin θ$\", fontsize=fs)\n",
    "plt.ylabel(r\"$(ω / k v_A)^2 \\, \\cos θ$\", fontsize=fs)\n",
    "plt.xlim(0.0, 2.0)\n",
    "plt.ylim(0.0, 2.0);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
